# Creative Commons Zero v1.0 Universal
# SPDX-License-Identifier: CC0-1.0
# Created by Douglas Stebila

from frodokem import FrodoKEM

class NISTKAT(object):
    """Helpers for printing a NIST-style KAT record for a FrodoKEM instance.

    Only a single record (``count = 0``) can be produced, because the random
    bytes it consumes are hard-coded in :class:`NISTKAT.NISTRNG` rather than
    generated by a real implementation of the NIST KAT DRBG.
    """

    @staticmethod
    def run(kem):
        """Generate a NIST KAT output for the given KEM.

        Installs a fresh :class:`NISTKAT.NISTRNG` as ``kem.randombytes``, runs
        key generation, encapsulation and decapsulation once, and prints the
        ``count``, ``seed``, ``pk``, ``sk``, ``ct`` and ``ss`` lines, with byte
        values rendered as upper-case hex.

        :param kem: object with a writable ``randombytes`` attribute and the
            methods ``kem_keygen()``, ``kem_encaps(pk)`` and
            ``kem_decaps(sk, ct)``; the byte strings they return must support
            ``.hex()``.
        :raises AssertionError: if the decapsulated shared secret differs from
            the encapsulated one, or if the KEM requests random bytes that the
            hard-coded RNG cannot supply.
        """
        rng = NISTKAT.NISTRNG()
        kem.randombytes = rng.randombytes
        print("count = 0")
        print("seed = <unspecified>")
        pk, sk = kem.kem_keygen()
        print("pk =", pk.hex().upper())
        print("sk =", sk.hex().upper())
        ct, ss_e = kem.kem_encaps(pk)
        print("ct =", ct.hex().upper())
        print("ss =", ss_e.hex().upper())
        ss_d = kem.kem_decaps(sk, ct)
        # Raise explicitly instead of using `assert`, so this correctness
        # check is not stripped when Python runs with -O.
        if ss_e != ss_d:
            raise AssertionError("Shared secrets not equal")

    class NISTRNG(object):
        """Replay the randombytes outputs of the NIST KAT DRBG for FrodoKEM.

        This is not an implementation of that RNG: it holds the exact byte
        strings the DRBG would produce, hard-coded as one sequence per
        expected first-request length (48, 64 or 80 bytes; presumably one
        sequence per FrodoKEM security level). The first call selects a
        sequence by its requested length; every later call must request
        exactly the length of the next stored output.
        """

        def __init__(self):
            # Sequence being replayed; None until the first randombytes call.
            self.current_outputs = None
            # Pre-programmed sequences. Their first outputs must have distinct
            # lengths, since that length is what selects a sequence.
            self.possible_outputs = [
                [
                    bytes.fromhex('7C9935A0B07694AA0C6D10E4DB6B1ADD2FD81A25CCB148032DCD739936737F2DB505D7CFAD1B497499323C8686325E47'),
                    bytes.fromhex('33B3C07507E4201748494D832B6EE2A6'),
                ],
                [
                    bytes.fromhex('7C9935A0B07694AA0C6D10E4DB6B1ADD2FD81A25CCB148032DCD739936737F2DB505D7CFAD1B497499323C8686325E4792F267AAFA3F87CA60D01CB54F29202A'),
                    bytes.fromhex('EB4A7C66EF4EBA2DDB38C88D8BC706B1D639002198172A7B'),
                ],
                [
                    bytes.fromhex('7C9935A0B07694AA0C6D10E4DB6B1ADD2FD81A25CCB148032DCD739936737F2DB505D7CFAD1B497499323C8686325E4792F267AAFA3F87CA60D01CB54F29202A3E784CCB7EBCDCFD45542B7F6AF77874'),
                    bytes.fromhex('8BF0F459F0FB3EA8D32764C259AE631178976BAF3683D33383188A65A4C2449B'),
                ],
            ]

        def randombytes(self, n):
            """Return the next hard-coded output, which must be exactly n bytes.

            :param n: number of random bytes requested.
            :returns: the next stored ``bytes`` value of length ``n``.
            :raises AssertionError: if no stored sequence starts with an
                n-byte output, if the selected sequence is exhausted, or if
                the next stored output is not n bytes long.
            """
            # Failures are raised explicitly rather than via `assert` so that
            # they still fire under `python -O` instead of returning bytes of
            # the wrong length or crashing on None.
            if self.current_outputs is None:
                selected = next(
                    (outputs for outputs in self.possible_outputs if len(outputs[0]) == n),
                    None,
                )
                if selected is None:
                    raise AssertionError("Could not find pre-programmed output of desired length {:d}".format(n))
                # Copy so that consuming outputs below leaves the
                # pre-programmed template in possible_outputs intact.
                self.current_outputs = list(selected)
            if not self.current_outputs:
                raise AssertionError("No remaining outputs available")
            if n != len(self.current_outputs[0]):
                raise AssertionError("Next output not of the appropriate length")
            return self.current_outputs.pop(0)

if __name__ == "__main__":
    # Print one KAT record for each supported FrodoKEM variant, in order:
    # each parameter size in its AES and SHAKE variants.
    for variant_name in (
        'FrodoKEM-640-AES',
        'FrodoKEM-640-SHAKE',
        'FrodoKEM-976-AES',
        'FrodoKEM-976-SHAKE',
        'FrodoKEM-1344-AES',
        'FrodoKEM-1344-SHAKE',
    ):
        NISTKAT.run(FrodoKEM(variant_name))
